import json
import time
import uuid

import paho.mqtt.client as mqtt
import ssl, base64, hmac
from hashlib import sha1
import tempfile
import os

from unilabos.config.config import MQConfig
from unilabos.app.controler import devices, job_add
from unilabos.app.model import JobAddReq, JobAddResp
from unilabos.utils import logger
from unilabos.utils.type_check import TypeEncoder


class MQTTClient:
    mqtt_disable = True

    def __init__(self):
        self.mqtt_disable = not MQConfig.lab_id
        self.client_id = f"{MQConfig.group_id}@@@{MQConfig.lab_id}{uuid.uuid4()}"
        self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id=self.client_id, protocol=mqtt.MQTTv5)
        self._setup_callbacks()

    def _setup_callbacks(self):
        self.client.on_log = self._on_log
        self.client.on_connect = self._on_connect
        self.client.on_message = self._on_message
        self.client.on_disconnect = self._on_disconnect

    def _on_log(self, client, userdata, level, buf):
        logger.info(f"[MQTT] log: {buf}")

    def _on_connect(self, client, userdata, flags, rc, properties=None):
        logger.info("[MQTT] Connected with result code " + str(rc))
        client.subscribe(f"labs/{MQConfig.lab_id}/job/start/", 0)
        isok, data = devices()
        if not isok:
            logger.error("[MQTT] on_connect ErrorHostNotInit")
            return

    def _on_message(self, client, userdata, msg):
        logger.info("[MQTT] on_message<<<< " + msg.topic + " " + str(msg.payload))
        try:
            payload_str = msg.payload.decode("utf-8")
            payload_json = json.loads(payload_str)
            logger.debug(f"Topic: {msg.topic}")
            logger.debug("Payload:", json.dumps(payload_json, indent=2, ensure_ascii=False))
            if msg.topic == f"labs/{MQConfig.lab_id}/job/start/":
                logger.debug("job_add", type(payload_json), payload_json)
                job_req = JobAddReq.model_validate(payload_json)
                data = job_add(job_req)
                return JobAddResp(data=data)

        except json.JSONDecodeError as e:
            logger.error(f"[MQTT] JSON 解析错误: {e}")
            logger.error(f"[MQTT] Raw message: {msg.payload}")
        except Exception as e:
            logger.error(f"[MQTT] 处理消息时出错: {e}")

    def _on_disconnect(self, client, userdata, rc, reasonCode=None, properties=None):
        if rc != 0:
            logger.error(f"[MQTT] Unexpected disconnection {rc}")

    def _setup_ssl_context(self):
        temp_files = []
        try:
            with tempfile.NamedTemporaryFile(mode="w", delete=False) as ca_temp:
                ca_temp.write(MQConfig.ca_content)
                temp_files.append(ca_temp.name)

            with tempfile.NamedTemporaryFile(mode="w", delete=False) as cert_temp:
                cert_temp.write(MQConfig.cert_content)
                temp_files.append(cert_temp.name)

            with tempfile.NamedTemporaryFile(mode="w", delete=False) as key_temp:
                key_temp.write(MQConfig.key_content)
                temp_files.append(key_temp.name)

            context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
            context.load_verify_locations(cafile=temp_files[0])
            context.load_cert_chain(certfile=temp_files[1], keyfile=temp_files[2])
            self.client.tls_set_context(context)
        finally:
            for temp_file in temp_files:
                try:
                    os.unlink(temp_file)
                except:
                    pass

    def start(self):
        if self.mqtt_disable:
            logger.warning("MQTT is disabled, skipping connection.")
            return
        userName = f"Signature|{MQConfig.access_key}|{MQConfig.instance_id}"
        password = base64.b64encode(
            hmac.new(MQConfig.secret_key.encode(), self.client_id.encode(), sha1).digest()
        ).decode()

        self.client.username_pw_set(userName, password)
        self._setup_ssl_context()

        # 创建连接线程
        def connect_thread_func():
            try:
                self.client.connect(MQConfig.broker_url, MQConfig.port, 60)
                self.client.loop_start()

                # 添加连接超时检测
                max_attempts = 5
                attempt = 0
                while not self.client.is_connected() and attempt < max_attempts:
                    logger.info(
                        f"[MQTT] 正在连接到 {MQConfig.broker_url}:{MQConfig.port}，尝试 {attempt+1}/{max_attempts}"
                    )
                    time.sleep(3)
                    attempt += 1

                if self.client.is_connected():
                    logger.info(f"[MQTT] 已成功连接到 {MQConfig.broker_url}:{MQConfig.port}")
                else:
                    logger.error(f"[MQTT] 连接超时，可能是账号密码错误或网络问题")
                    self.client.loop_stop()
            except Exception as e:
                logger.error(f"[MQTT] 连接失败: {str(e)}")

        connect_thread_func()
        # connect_thread = threading.Thread(target=connect_thread_func)
        # connect_thread.daemon = True
        # connect_thread.start()

    def stop(self):
        if self.mqtt_disable:
            return
        self.client.disconnect()
        self.client.loop_stop()

    def publish_device_status(self, device_status: dict, device_id, property_name):
        # status = device_status.get(device_id, {})
        if self.mqtt_disable:
            return
        status = {"data": device_status.get(device_id, {}), "device_id": device_id}
        address = f"labs/{MQConfig.lab_id}/devices"
        self.client.publish(address, json.dumps(status), qos=2)
        logger.critical(f"Device status published: address: {address}, {status}")

    def publish_job_status(self, feedback_data: dict, job_id: str, status: str):
        if self.mqtt_disable:
            return
        jobdata = {"job_id": job_id, "data": feedback_data, "status": status}
        self.client.publish(f"labs/{MQConfig.lab_id}/job/list/", json.dumps(jobdata), qos=2)
    
    def publish_registry(self, device_id: str, device_info: dict):
        if self.mqtt_disable:
            return
        address = f"labs/{MQConfig.lab_id}/registry/"
        registry_data = json.dumps({device_id: device_info}, ensure_ascii = False, cls = TypeEncoder)
        self.client.publish(address, registry_data, qos=2)
        logger.debug(f"Registry data published: address: {address}, {registry_data}")

    def publish_actions(self, action_id: str, action_info: dict):
        if self.mqtt_disable:
            return
        address = f"labs/{MQConfig.lab_id}/actions/"
        action_type_name = action_info["title"]
        action_info["title"] = action_id
        action_data = json.dumps({action_type_name: action_info}, ensure_ascii=False)
        self.client.publish(address, action_data, qos=2)
        logger.debug(f"Action data published: address: {address}, {action_data}")


mqtt_client = MQTTClient()

if __name__ == "__main__":
    mqtt_client.start()
